import torch
from pytorch_transformers import BertTokenizer
from utils.utils import load_text
from models.model_builder import BertSummarizer
from models.predictor import build_predictor
import os
import glob
import gdown
import csv
import codecs

# MODEL_URL must be a Google Drive shareable link in the 'uc?id=' form (consumed by gdown.download below)
MODEL_URL = 'https://drive.google.com/a/boastcapital.com/uc?id=1-IKVCtc4Q-BdZpjXc4s70_fRsWnjtYLr'
import rouge


def prepare_results(p, r, f, metric):
    """Return a tab-separated report line for one ROUGE metric.

    Args:
        p: precision in [0, 1].
        r: recall in [0, 1].
        f: F1 score in [0, 1].
        metric: metric name, e.g. 'rouge-1'.
    """
    labels = ('P', 'R', 'F1')
    # Scores arrive as fractions; report them as percentages.
    cells = '\t'.join(f'{label}: {100.0 * value:5.2f}'
                      for label, value in zip(labels, (p, r, f)))
    return f'\t{metric}:\t{cells}'

def compute_rouge(hypothesis, references):
    """Print ROUGE-1..4, ROUGE-L and ROUGE-W scores for hypotheses vs references.

    Scores are printed twice: once averaged over all hypothesis/reference
    pairs ('Avg') and once keeping the best-matching reference per
    hypothesis ('Best').

    Args:
        hypothesis: generated summaries (as accepted by rouge.Rouge.get_scores).
        references: reference summaries, aligned with *hypothesis*.
    """
    for aggregator in ('Avg', 'Best'):
        print('Evaluation with {}'.format(aggregator))

        evaluator = rouge.Rouge(metrics=['rouge-n', 'rouge-l', 'rouge-w'],
                                max_n=4,
                                limit_length=True,
                                length_limit=100,
                                length_limit_type='words',
                                apply_avg=aggregator == 'Avg',
                                apply_best=aggregator == 'Best',
                                alpha=0.5,  # weight of precision vs recall (0.5 = F1)
                                weight_factor=1.2,
                                stemming=True)

        scores = evaluator.get_scores(hypothesis, references)

        # Exactly one of apply_avg/apply_best is always set, so each entry in
        # `scores` is a single {'p', 'r', 'f'} dict.  The per-reference list
        # branch from the original (taken only when both flags are False) was
        # unreachable and has been removed.
        for metric, results in sorted(scores.items()):
            print(prepare_results(results['p'], results['r'], results['f'], metric))
        print()

class AbstractSummarizer(object):
    """Abstractive text summarizer backed by a fine-tuned BERT checkpoint.

    On first use the checkpoint is downloaded from Google Drive (MODEL_URL)
    into the local ``cache/`` directory; subsequent runs load it from disk.
    """

    def __init__(self, model_path='cache/abs_bert_model.pt'):
        """Load the summarization model, downloading it first if necessary.

        Args:
            model_path: filesystem location of the fine-tuned ``.pt`` checkpoint.
        """
        cache_dir = 'cache'
        # Single idempotent mkdir; the original checked/created 'cache' twice.
        os.makedirs(cache_dir, exist_ok=True)

        # Fetch the checkpoint from Google Drive on first run.
        if not os.path.exists(model_path):
            print('Model not found in cache')
            self.download_model(model_path)

        # Load weights onto CPU regardless of where they were saved;
        # BertSummarizer is handed the target device explicitly.
        checkpoint = torch.load(
            model_path, map_location=lambda storage, loc: storage)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = BertSummarizer(checkpoint, device, cache_dir)
        self.model.eval()

        tokenizer = BertTokenizer.from_pretrained(
            'bert-base-uncased', do_lower_case=True, cache_dir=cache_dir)
        self.predictor = build_predictor(tokenizer, self.model)

    def get_summary(self, texts):
        """Summarize *texts* and return the predictor's output.

        Args:
            texts: raw input texts in whatever form utils.load_text accepts.
        """
        text_iter = load_text(texts)
        results = self.predictor.translate(text_iter)
        print(results)  # NOTE(review): debug output — consider a logger instead
        return results

    @staticmethod
    def download_model(model_path):
        """Download the zipped checkpoint from MODEL_URL and place it at *model_path*."""
        temp = 'cache/temp.zip'
        gdown.download(MODEL_URL, temp, quiet=False)

        # Unpack, then remove the archive.
        print('unzipping file')
        gdown.extractall(temp)
        os.remove(temp)

        # The .pt filename inside the archive is unknown ahead of time;
        # rename the first extracted checkpoint to the expected path.
        os.rename(glob.glob('cache/*.pt')[0], model_path)


if __name__ == '__main__':
    summarizer = AbstractSummarizer()

    # Read the validation CSV once (the original opened and parsed it twice):
    # column 1 holds the article text, column 2 the reference summary.
    # Only the first two rows are used to keep this smoke run fast.
    with codecs.open('./news_summary/news_summs_validation.csv', 'r', 'utf-8',
                     errors='ignore') as f:
        rows = list(csv.reader(f))[:2]

    news = [row[1] for row in rows]
    summaries = [row[2] for row in rows]

    compute_rouge(summarizer.get_summary(news), summaries)